浅谈KMP

什么是 $\mathbf{KMP}$

$KMP$ 算法是一种改进的字符串匹配算法,由$D.E.$ K $nuth$,$J.H.$ M $orris$和$V.R.$ P $ratt$提出的,因此人们称它为克努特—莫里斯—普拉特操作(简称 $KMP$ 算法)。 $KMP$ 算法的关键是利用匹配失败后的信息,尽量减少模式串与主串的匹配次数以达到快速匹配的目的。具体实现就是实现一个 $next()$ 函数,函数本身包含了模式串的局部匹配信息。时间复杂度 $\Theta(m+n)$ 。

(From 百度百科)

$\mathbf{KMP}$ 算法的思想

首先我们考虑朴素的字符串匹配。通过两个指针分别指向文本串和模式串,如果失配,那么文本串指针会指向开始匹配的下一位,而模式串的指针则会指向串首。如下图:

但是我们可以发现,我们在第一次失配前已经知道文本串的前 $5$ 个字符,那么我们就可以找到一种更优的失配转移,如下图:

我们可以发现,我们可以找到模式串下标为 $n$ 时失配的更优失配转移,当且仅当模式串下标为 $[0,n]$ 的子串中的前缀等于后缀,如下图:

显然,子串中最长前缀等于最长后缀时,可以找到最优失配转移。由此得到 $KMP$ 算法的思路:预处理模式串,得到失配数组( $fail$ 数组),再进行匹配。

如何预处理出 $\mathbf{fail}$ 数组

从上文可以看出,我们需要找到最长的前缀等于后缀。建立数组 $fail[]$ 储存失配后2指针应跳转的位置。不难想到, $fail[0]=fail[1]=0$ 因为这两个位置失配后必须回到字符串起点。接下来我们可以不停的对当前失配的位置回跳 $fail$ 指针,直到匹配或跳到字符串首。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void get_fail(char s[])
{
fail[0]=0;
fail[1]=0;
for(register int i=1;i<strlen(s);++i)
{
int t=fail[i];

while(t&&s[t]!=s[i]) t=fail[t];
if(s[t]==s[i]) fail[i+1]=t+1;
else fail[i+1]=0;
}

return;
}

同理,我们在 $kmp$ 的时候对两个字符串进行这样的操作。对比来看, $get\text{_}fail$ 操作其实就是模式串和文本串都是模式串的 $kmp$ 。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
void kmp(char s[],char t[])
{
int now=0,len=strlen(t);

get_fail(t);
for(register int i=0;i<strlen(s);++i)
{
while(now&&s[i]!=t[now]) now=fail[now];
if(s[i]==t[now]) now++;
if(now==len)
{
printf("%d\n",i-len+1);
now=fail[now];
}
}

return;
}

结语

$KMP$ 算法具有一种“最优历史处理”的性质,而这种性质也是基于 $KMP$ 的核心思想的。考虑前缀与后缀的关系使得其在匹配的时候不会反反复复地找。最后放上模板和 $std$ : P3375 【模板】KMP字符串匹配

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include <cstdio>
#include <string>
#include <iostream>
using namespace std;

string a,b;
int fail[1000005];

void get_fail(string s)
{
fail[0]=0;
fail[1]=0;
for(unsigned int i=1;i<s.length();i++)
{
int t=fail[i];

while(t&&s.at(t)!=s.at(i)) t=fail[t];
if(s.at(t)==s.at(i)) fail[i+1]=t+1;
else fail[i+1]=0;
}

return;
}

void kmp(string s1,string s2)
{
unsigned int now=0;

get_fail(s2);
for(unsigned int i=0;i<s1.length();i++)
{
while(now&&s1.at(i)!=s2.at(now)) now=fail[now];
if(s1.at(i)==s2.at(now)) now++;
if(now==s2.length())
{
printf("%d\n",i-s2.length()+1);
now=fail[now];
}
}

return;
}

int main()
{
cin>>a>>b;

a=" "+a;
kmp(a,b);

for(unsigned int i=1;i<=b.length();i++) printf("%d ",fail[i]);
return 0;
}